import torch.nn.functional as F
import torch

if __name__ == '__main__':
    # Demo: draw a random tensor of explicit shape (3, 4, 3) and reduce it
    # over its last axis (dim 2), leaving a (3, 4) tensor of per-row totals.
    sample = torch.randn(3, 4, 3)
    row_totals = sample.sum(dim=2)
    print(row_totals)

